from typing import Any, Optional, NamedTuple, Iterable, Callable
import jax.numpy as jnp
import jax
import haiku as hk
import matplotlib.pyplot as plt
import math
import numpy as np
import optax
import functools
from tqdm import tqdm
import tensorflow_probability.substrates.jax as tfp

from regression import make_hetstat_mlp, make_sac_mlp, train_model

# Short module-level aliases for the Haiku initializers namespace and the
# TensorFlow Probability (JAX substrate) distributions namespace.
# Neither alias is referenced anywhere else in this file.
hk_init = hk.initializers
tfd = tfp.distributions


def _plot_mean_with_band(ax, x, mean, std):
    """Draw ``mean`` as a blue line on ``ax`` with a shaded ``mean +/- std`` band."""
    upper = mean + std
    lower = mean - std
    ax.plot(x, mean, "b")
    ax.fill_between(x, upper, lower, where=upper >= lower, color="b", alpha=0.3)


def _fit_and_plot(make_network, layer_sizes, s, a_noisy, x, axes, **shared_kwargs):
    """Train one model on ``(s, a_noisy)``, plot its prediction, return train MSE.

    Args:
        make_network: Network factory, called as
            ``make_network(1, layer_sizes, **shared_kwargs)``; its result is
            wrapped with ``hk.transform`` / ``hk.without_apply_rng``.
        layer_sizes: Layer-size list forwarded to ``make_network``.
        s: 1-D training inputs.
        a_noisy: 1-D training targets.
        x: 1-D evaluation grid on which the prediction is plotted.
        axes: Iterable of matplotlib axes; the same prediction is drawn on each.
        **shared_kwargs: Forwarded unchanged to BOTH ``make_network`` and
            ``train_model`` (e.g. ``faithful=False``). When empty, both calls
            fall back to their own defaults.

    Returns:
        Mean squared error between ``a_noisy`` and the predictive mode at ``s``.
    """
    network = hk.without_apply_rng(
        hk.transform(make_network(1, layer_sizes, **shared_kwargs))
    )
    # train_model also returns the trained parameters; they are not needed here.
    policy, _ = train_model(network, s[:, None], a_noisy[:, None], **shared_kwargs)

    dist = policy(x[:, None])
    mean = dist.mode().squeeze()
    std = jnp.sqrt(dist.variance()).squeeze()
    for ax in axes:
        _plot_mean_with_band(ax, x, mean, std)

    fitted = policy(s[:, None]).mode().squeeze()
    return ((a_noisy - fitted) ** 2).mean()


def _decorate_axis(ax, s, a_noisy):
    """Overlay the training points on ``ax`` and strip ticks and tick labels."""
    ax.plot(s, a_noisy, "k.", markersize=1)
    ax.set_xlim(-20, 20)
    ax.set_xticklabels([])
    ax.set_xticks([])
    ax.set_yticklabels([])
    ax.set_yticks([])
    ax.set_xlabel("$s$")


def main():
    """Fit four heteroscedastic regressors to a noisy 1-D function and save figures.

    Produces two PDF files in the current working directory:

    * ``heteroscedastic_figure.pdf`` - a kernel-regression reference ("Desired"),
      the faithful MLP and the faithful stationary MLP.
    * ``faithful_figure.pdf`` - MLP and stationary MLP, each trained with
      ``faithful=False`` (labelled NLLH) and with the default setting
      (labelled Faithful), with the training MSE in each title.

    ``plt.rc("text", usetex=True)`` is set, so matplotlib's LaTeX text
    rendering must be usable in the environment.
    """
    plt.rc("text", usetex=True)
    plt.rc("font", family="serif", size=9)

    np.random.seed(1)

    fig, axs = plt.subplots(1, 3, figsize=(6.6, 1.3), sharex=True, sharey=True)
    fig_f, axs_f = plt.subplots(1, 4, figsize=(6.6, 3.3), sharex=True, sharey=True)

    def function(x):
        return 0.5 * np.sin(x - 0.2) - 0.3 * np.cos(3 * x + 0.5)

    def noisy_function(x):
        # Input-dependent Gaussian noise with standard deviation |0.1 sin(x)|.
        noise = np.abs(0.1 * np.sin(x)) * np.random.randn(*x.shape)
        return function(x) + noise

    # Training data on [-6, 6]; the evaluation grid extends well beyond it.
    n = 150
    s = np.linspace(-6, 6, n)
    a_clean = function(s)
    a_noisy = noisy_function(s)

    x = np.linspace(-20, 20, 1000)

    def kernel(x, y):
        # Squared-exponential kernel on column-vector inputs.
        return np.exp(-0.5 * (x - y.T) ** 2 / 0.01)

    # Reference panel: kernel regression on the clean targets.
    K = kernel(s[:, None], s[:, None])
    K_ = kernel(x[:, None], s[:, None])
    K_reg = K + 1e-3 * np.eye(n)  # regularised Gram matrix, reused below
    y = K_ @ np.linalg.solve(K_reg, a_clean)

    var = kernel(x[:, None], x[:, None]) - K_ @ np.linalg.solve(K_reg, K_.T)
    # Add the input-dependent noise term only inside the training range (-6, 6).
    # `np.float` was removed in NumPy 1.24; the builtin `float` is the same dtype.
    # NOTE(review): noisy_function uses |0.1 sin(x)| as a standard deviation, but
    # here it is added to the variance un-squared - confirm this is intended.
    in_range = (x > -6).astype(float) * (x < 6).astype(float)
    var += np.diag(np.abs(0.1 * np.sin(x)) * in_range)
    std = np.sqrt(np.diag(var))
    _plot_mean_with_band(axs[0], x, y, std)

    # The training order below is kept as-is in case train_model consumes
    # global random state.

    # MLP (NLLH)
    mlp_mse_n = _fit_and_plot(
        make_sac_mlp, [256, 256], s, a_noisy, x, [axs_f[0]], faithful=False
    )

    # MLP (Faithful)
    mlp_mse_f = _fit_and_plot(
        make_sac_mlp, [256, 256], s, a_noisy, x, [axs[1], axs_f[1]]
    )

    # HETSTAT (Faithful)
    hetstat_mse_f = _fit_and_plot(
        make_hetstat_mlp, [256, 256, 12, 256], s, a_noisy, x, [axs[2], axs_f[3]]
    )

    # HETSTAT NLLH
    hetstat_mse_n = _fit_and_plot(
        make_hetstat_mlp,
        [256, 256, 12, 256],
        s,
        a_noisy,
        x,
        [axs_f[2]],
        faithful=False,
    )

    axs[0].set_title("Desired")
    axs[1].set_title("Heteroscedastic MLP")
    axs[2].set_title("Stationary Heteroscedastic MLP")

    axs[0].set_ylabel("$a$")

    axs_f[0].set_ylabel("$a$")
    axs_f[0].set_title(f"Heteroscedastic\n MLP\n (NLLH)\n Train MSE = {mlp_mse_n:.2f}")
    axs_f[1].set_title(
        f"Heteroscedastic\n MLP\n (Faithful)\n Train MSE = {mlp_mse_f:.2f}"
    )
    axs_f[2].set_title(
        f"Stationary\n Heteroscedastic\n MLP (NLLH)\n Train MSE = {hetstat_mse_n:.2f}"
    )
    axs_f[3].set_title(
        f"Stationary\n Heteroscedastic\n MLP (Faithful)\n Train MSE = {hetstat_mse_f:.2f}"
    )

    # Same decoration for every panel of both figures.
    for ax in [*axs, *axs_f]:
        _decorate_axis(ax, s, a_noisy)

    fig.tight_layout()
    fig.savefig("heteroscedastic_figure.pdf", bbox_inches="tight")

    fig_f.tight_layout()
    fig_f.savefig("faithful_figure.pdf", bbox_inches="tight")

# Script entry point: main() builds and saves both figures to PDF, after
# which plt.show() is called to display the open figures.
if __name__ == "__main__":
    main()
    plt.show()
